{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# TensorFlow 2.0+ Low Level APIs Convert Example\n",
    "\n",
    "This example demonstrates the workflow to build a model using\n",
    "TensorFlow 2.0+ low-level APIs and convert it to Core ML \n",
    "`.mlmodel` format using the `coremltools.converters.tensorflow` converter.\n",
    "For more example, refer `test_tf_2x.py` file.\n",
    "\n",
    "Note: \n",
    "\n",
    "- This notebook was tested with following dependencies:\n",
    "\n",
    "```\n",
    "tensorflow==2.0.0\n",
    "coremltools==3.1\n",
    "```\n",
    "\n",
    "- Models from TensorFlow 2.0+ is supported only for `minimum_ios_deployment_target>=13`.\n",
    "You can also use `tfcoreml.convert()` instead of \n",
    "`coremltools.converters.tensorflow.convert()` to convert your model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "WARNING: Logging before flag parsing goes to stderr.\n",
      "W1101 14:02:33.174557 4762860864 __init__.py:74] TensorFlow version 2.0.0 detected. Last version known to be fully compatible is 1.14.0 .\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2.0.0\n",
      "3.1\n"
     ]
    }
   ],
   "source": [
    "import tensorflow as tf\n",
    "import numpy as np\n",
    "import coremltools\n",
    "\n",
    "print(tf.__version__)\n",
    "print(coremltools.__version__)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Using Low-Level APIs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "W1101 14:02:33.537978 4762860864 deprecation.py:506] From /Volumes/data/venv-py36/lib/python3.6/site-packages/tensorflow_core/python/ops/resource_variable_ops.py:1781: calling BaseResourceVariable.__init__ (from tensorflow.python.ops.resource_variable_ops) with constraint is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "If using Keras pass *_constraint arguments to layers.\n"
     ]
    }
   ],
   "source": [
    "# construct a toy model with low level APIs\n",
    "root = tf.train.Checkpoint()\n",
    "root.v1 = tf.Variable(3.)\n",
    "root.v2 = tf.Variable(2.)\n",
    "root.f = tf.function(lambda x: root.v1 * root.v2 * x)\n",
    "\n",
    "# save the model\n",
    "saved_model_dir = './tf_model'\n",
    "input_data = tf.constant(1., shape=[1, 1])\n",
    "to_save = root.f.get_concrete_function(input_data)\n",
    "tf.saved_model.save(root, saved_model_dir, to_save)\n",
    "\n",
    "tf_model = tf.saved_model.load(saved_model_dir)\n",
    "concrete_func = tf_model.signatures[tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0 assert nodes deleted\n",
      "['Func/StatefulPartitionedCall/input/_2:0', 'StatefulPartitionedCall/mul/ReadVariableOp:0', 'statefulpartitionedcall_args_1:0', 'Func/StatefulPartitionedCall/input/_3:0', 'StatefulPartitionedCall/mul:0', 'StatefulPartitionedCall/ReadVariableOp:0', 'statefulpartitionedcall_args_2:0']\n",
      "6 nodes deleted\n",
      "0 nodes deleted\n",
      "0 nodes deleted\n",
      "2 identity nodes deleted\n",
      "0 disconnected nodes deleted\n",
      "[SSAConverter] Converting function main ...\n",
      "[SSAConverter] [1/3] Converting op type: 'Placeholder', name: 'x', output_shape: (1, 1).\n",
      "[SSAConverter] [2/3] Converting op type: 'Const', name: 'StatefulPartitionedCall/mul'.\n",
      "[SSAConverter] [3/3] Converting op type: 'Mul', name: 'Identity', output_shape: (1, 1).\n"
     ]
    }
   ],
   "source": [
    "# convert model into Core ML format\n",
    "model = coremltools.converters.tensorflow.convert(\n",
    "    [concrete_func],\n",
    "    inputs={'x': (1, 1)},\n",
    "    outputs=['Identity']\n",
    ")\n",
    "\n",
    "assert isinstance(model, coremltools.models.MLModel)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Using Control Flow"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# construct a TensorFlow 2.0+ model with tf.function()\n",
    "\n",
    "@tf.function(input_signature=[tf.TensorSpec([], tf.float32)])\n",
    "def control_flow(x):\n",
    "    if x <= 0:\n",
    "        return 0.\n",
    "    else:\n",
    "        return x * 3.\n",
    "\n",
    "to_save = tf.Module()\n",
    "to_save.control_flow = control_flow\n",
    "\n",
    "saved_model_dir = './tf_model'\n",
    "tf.saved_model.save(to_save, saved_model_dir)\n",
    "tf_model = tf.saved_model.load(saved_model_dir)\n",
    "concrete_func = tf_model.signatures[tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0 assert nodes deleted\n",
      "['PartitionedCall/cond/then/_2/Identity_1:0', 'PartitionedCall/LessEqual/y:0', 'PartitionedCall/cond/else/_3/mul/y:0', 'Func/PartitionedCall/cond/then/_2/output/_14:0', 'PartitionedCall/cond/then/_2/Const_1:0']\n",
      "2 nodes deleted\n",
      "Fixing cond at merge location: PartitionedCall/cond/output/_9\n",
      "In an IFF node  fp32  !=  tensor[fp32,1]\n",
      "0 nodes deleted\n",
      "0 nodes deleted\n",
      "2 identity nodes deleted\n",
      "0 disconnected nodes deleted\n",
      "[SSAConverter] Converting function main ...\n",
      "[SSAConverter] [1/7] Converting op type: 'Placeholder', name: 'x', output_shape: (1,).\n",
      "[SSAConverter] [2/7] Converting op type: 'Const', name: 'PartitionedCall/LessEqual/y'.\n",
      "[SSAConverter] [3/7] Converting op type: 'Const', name: 'Func/PartitionedCall/cond/then/_2/output/_14'.\n",
      "[SSAConverter] [4/7] Converting op type: 'Const', name: 'PartitionedCall/cond/else/_3/mul/y'.\n",
      "[SSAConverter] [5/7] Converting op type: 'LessEqual', name: 'PartitionedCall/LessEqual', output_shape: (1,).\n",
      "[SSAConverter] [6/7] Converting op type: 'Mul', name: 'PartitionedCall/cond/else/_3/mul', output_shape: (1,).\n",
      "[SSAConverter] [7/7] Converting op type: 'iff', name: 'Identity'.\n"
     ]
    }
   ],
   "source": [
    "# convert model into Core ML format\n",
    "model = coremltools.converters.tensorflow.convert(\n",
    "    [concrete_func],\n",
    "    inputs={'x': (1,)},\n",
    "    outputs=['Identity']\n",
    ")\n",
    "\n",
    "assert isinstance(model, coremltools.models.MLModel)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# try with some sample inputs\n",
    "\n",
    "inputs = [-3.7, 6.17, 0.0, 1984., -5.]\n",
    "for data in inputs:\n",
    "    out1 = to_save.control_flow(data).numpy()\n",
    "    out2 = model.predict({'x': np.array([data])})['Identity']\n",
    "    np.testing.assert_array_almost_equal(out1, out2)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Using `tf.keras` Subclassing APIs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "class MyModel(tf.keras.Model):\n",
    "    def __init__(self):\n",
    "        super(MyModel, self).__init__()\n",
    "        self.dense1 = tf.keras.layers.Dense(4)\n",
    "        self.dense2 = tf.keras.layers.Dense(5)\n",
    "\n",
    "    @tf.function\n",
    "    def call(self, input_data):\n",
    "        return self.dense2(self.dense1(input_data))\n",
    "\n",
    "keras_model = MyModel()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0 assert nodes deleted\n",
      "['my_model/StatefulPartitionedCall/args_3:0', 'Func/my_model/StatefulPartitionedCall/input/_2:0', 'Func/my_model/StatefulPartitionedCall/StatefulPartitionedCall/input/_11:0', 'my_model/StatefulPartitionedCall/args_4:0', 'Func/my_model/StatefulPartitionedCall/input/_4:0', 'Func/my_model/StatefulPartitionedCall/StatefulPartitionedCall/input/_12:0', 'my_model/StatefulPartitionedCall/args_2:0', 'my_model/StatefulPartitionedCall/StatefulPartitionedCall/dense_1/StatefulPartitionedCall/MatMul/ReadVariableOp:0', 'Func/my_model/StatefulPartitionedCall/StatefulPartitionedCall/dense_1/StatefulPartitionedCall/input/_25:0', 'Func/my_model/StatefulPartitionedCall/input/_3:0', 'Func/my_model/StatefulPartitionedCall/StatefulPartitionedCall/input/_13:0', 'Func/my_model/StatefulPartitionedCall/input/_5:0', 'Func/my_model/StatefulPartitionedCall/StatefulPartitionedCall/input/_10:0', 'Func/my_model/StatefulPartitionedCall/StatefulPartitionedCall/dense_1/StatefulPartitionedCall/input/_24:0', 'my_model/StatefulPartitionedCall/args_1:0', 'Func/my_model/StatefulPartitionedCall/StatefulPartitionedCall/dense/StatefulPartitionedCall/input/_18:0', 'my_model/StatefulPartitionedCall/StatefulPartitionedCall/dense/StatefulPartitionedCall/BiasAdd/ReadVariableOp:0', 'my_model/StatefulPartitionedCall/StatefulPartitionedCall/dense/StatefulPartitionedCall/MatMul/ReadVariableOp:0', 'my_model/StatefulPartitionedCall/StatefulPartitionedCall/dense_1/StatefulPartitionedCall/BiasAdd/ReadVariableOp:0', 'Func/my_model/StatefulPartitionedCall/StatefulPartitionedCall/dense/StatefulPartitionedCall/input/_19:0']\n",
      "16 nodes deleted\n",
      "0 nodes deleted\n",
      "0 nodes deleted\n",
      "[Op Fusion] fuse_bias_add() deleted 4 nodes.\n",
      "2 identity nodes deleted\n",
      "2 disconnected nodes deleted\n",
      "[SSAConverter] Converting function main ...\n",
      "[SSAConverter] [1/3] Converting op type: 'Placeholder', name: 'input_1', output_shape: (4, 4).\n",
      "[SSAConverter] [2/3] Converting op type: 'MatMul', name: 'my_model/StatefulPartitionedCall/StatefulPartitionedCall/dense/StatefulPartitionedCall/MatMul', output_shape: (4, 4).\n",
      "[SSAConverter] [3/3] Converting op type: 'MatMul', name: 'Identity', output_shape: (4, 5).\n"
     ]
    }
   ],
   "source": [
    "inputs = np.random.rand(4, 4)\n",
    "\n",
    "# subclassed model can only be saved as SavedModel format\n",
    "keras_model._set_inputs(inputs)\n",
    "saved_model_dir = './tf_model_subclassing'\n",
    "keras_model.save(saved_model_dir, save_format='tf')\n",
    "# convert and validate\n",
    "model = coremltools.converters.tensorflow.convert(\n",
    "    saved_model_dir,\n",
    "    inputs={'input_1': (4, 4)},\n",
    "    outputs=['Identity']\n",
    ")\n",
    "assert isinstance(model, coremltools.models.MLModel)\n",
    "# verify the prediction matches\n",
    "keras_prediction = keras_model.predict(inputs)\n",
    "prediction = model.predict({'input_1': inputs})['Identity']\n",
    "np.testing.assert_array_equal(keras_prediction.shape, prediction.shape)\n",
    "np.testing.assert_almost_equal(keras_prediction.flatten(), prediction.flatten(), decimal=4)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
